from __future__ import absolute_import
from __future__ import print_function
import os
import sys
import numpy as np

sys.path.append('../tool')
import toolkits
import src.utils as ut
import datetime
import pandas as pd

#import pdb
# ===========================================
#        Score every speaker's recordings against each other
# ===========================================
import argparse

# Command-line configuration for the scoring run.
parser = argparse.ArgumentParser()
# set up training configuration.
parser.add_argument('--gpu', default='', type=str)
parser.add_argument('--resume', default='../model/gvlad_softmax/resnet34_vlad8_ghost2_bdim512_deploy/weights.h5',
                    type=str)
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('--data_path', default='../audio_test3_nosur', type=str)
# set up network configuration.
parser.add_argument('--net', default='resnet34s', choices=['resnet34s', 'resnet34l'], type=str)
parser.add_argument('--ghost_cluster', default=2, type=int)
parser.add_argument('--vlad_cluster', default=8, type=int)
parser.add_argument('--bottleneck_dim', default=512, type=int)
parser.add_argument('--aggregation_mode', default='gvlad', choices=['avg', 'vlad', 'gvlad'], type=str)
# set up learning rate, training loss and optimizer.
parser.add_argument('--loss', default='softmax', choices=['softmax', 'amsoftmax'], type=str)
parser.add_argument('--test_type', default='normal', choices=['normal', 'hard', 'extend'], type=str)

# NOTE: the original `global args` here was a no-op (`global` has no effect
# at module scope); `args` is simply a module-level name read by main() below.
args = parser.parse_args()


def main():
    # gpu configuration
    toolkits.initialize_GPU(args)

    # from src import model
    import model
    # ==================================
    #       Get Train/Val.
    # ==================================
    #print('==> calculating test({}) data lists...'.format(args.test_type))
   #写文件夹中所有wav与所有wav的txt文件
    txt_list = []
    for category in os.listdir(args.data_path):
        speaker_file = os.path.join(args.data_path, category)
        for category2 in os.listdir(speaker_file):
            speaker_wav1 = os.path.join(category, category2)
            for category3 in os.listdir(args.data_path):
                speaker_file2 = os.path.join(args.data_path, category3)
                for category4 in os.listdir(speaker_file2):
                    speaker_wav2 = os.path.join(category3, category4)
                    txt_list.append([1, speaker_wav1 , speaker_wav2])
    #print(txt_list)
    #print(len(txt_list))


    #verify_list = np.loadtxt('../meta/0226test.txt', str)
    #verify_list = np.loadtxt('../meta/test2.txt', str)
    #print(verify_list[0])
    #verify_lb = np.array([int(i[0]) for i in verify_list])
    #list1 = np.array([os.path.join(args.data_path, i[1]) for i in verify_list])
    #list2 = np.array([os.path.join(args.data_path, i[2]) for i in verify_list])
    list1 = np.array([os.path.join(args.data_path, i[1]) for i in txt_list])
    list2 = np.array([os.path.join(args.data_path, i[2]) for i in txt_list])
    #print(list1,list2)
    total_list = np.concatenate((list1, list2))
    unique_list = np.unique(total_list)
    #print(unique_list)
    #print(len(unique_list))
    unique_list_name = []
    for i in range(len(unique_list)):
        name_wav = unique_list[i].split('/')[3]
        name = name_wav.split('1')[0]
        unique_list_name.append(name)
    print(unique_list_name)



    # ==================================
    #       Get Model
    # ==================================
    # construct the data generator.
    params = {'dim': (257, None, 1),
              'nfft': 512,
              'spec_len': 250,
              'win_length': 400,
              'hop_length': 160,
              'n_classes': 5994,
              'sampling_rate': 16000,
              'normalize': True,
              }

    network_eval = model.vggvox_resnet2d_icassp(input_dim=params['dim'],
                                                num_class=params['n_classes'],
                                                mode='eval', args=args)

    # ==> load pre-trained model ???
    if args.resume:
        # ==> get real_model from arguments input,
        # load the model if the imag_model == real_model.
        if os.path.isfile(args.resume):
            network_eval.load_weights(os.path.join(args.resume), by_name=True)
            print('==> successfully loading model {}.'.format(args.resume))
        else:
            raise IOError("==> no checkpoint found at '{}'".format(args.resume))
    else:
        raise IOError('==> please type in the model to load')

    print('==> start testing.')

    # The feature extraction process has to be done sample-by-sample,
    # because each sample is of different lengths.
    total_length = len(unique_list)
    feats, scores, labels = [], [], []
    for c, ID in enumerate(unique_list):
        if c % 50 == 0: print('Finish extracting features for {}/{}th wav.'.format(c, total_length))
        specs = ut.load_data(ID, win_length=params['win_length'], sr=params['sampling_rate'],
                             hop_length=params['hop_length'], n_fft=params['nfft'],
                             spec_len=params['spec_len'], mode='eval')
        specs = np.expand_dims(np.expand_dims(specs, 0), -1)

        v = network_eval.predict(specs)
        feats += [v]

    feats = np.array(feats)

    # ==> compute the pair-wise similarity.

    end_list = []
    re_list = []
    for c, (p1, p2) in enumerate(zip(list1, list2)):
        ind1 = np.where(unique_list == p1)[0][0]
        ind2 = np.where(unique_list == p2)[0][0]
        #print(ind1,ind2)
        v1 = feats[ind1, 0]
        v2 = feats[ind2, 0]

        scores += [np.sum(v1 * v2)]
        print(unique_list[ind1],unique_list[ind2])
        print('scores : {}'.format(scores[-1]))
        re_list.append([ind1, ind2, scores[-1]])
        end_list.append([unique_list[ind1],unique_list[ind2], scores[-1]])
    #print(re_list)
    end_list_frame = pd.DataFrame(end_list)
    end_list_frame.to_csv('../result/no_sur/end_result.txt', header=0)

    re_frame = pd.DataFrame(re_list, columns =['first_wav' , 'second_wav' , 'score'])
    #print(re_frame['first_wav'].unique())
    re_frame = re_frame[re_frame['score'] < 0.99]
    re_frame = re_frame.reset_index(drop = True)

    name_score_list = []
    for i in range(len(re_frame)):
        first_wav_name = unique_list_name[re_frame['first_wav'][i]]
        second_wav_name = unique_list_name[re_frame['second_wav'][i]]
        score = re_frame['score'][i]
        name_score_list.append([first_wav_name, second_wav_name, score])
    #print(name_score_list)
     ##找敌人最大
    name_score_frame = pd.DataFrame(name_score_list, columns = ['first_wav_name', 'second_wav_name', 'score'])
    score_list_max = []
    for i in range(len(name_score_frame['first_wav_name'].unique())):
        name = name_score_frame['first_wav_name'].unique()[i]
        diff_name_frame = name_score_frame[(name_score_frame['first_wav_name'] == name) & (name_score_frame['second_wav_name'] != name)]
        max_diff_score = max(diff_name_frame['score'])
        print(max_diff_score)
        score_list_max.append([name, max_diff_score])

    score_frame_max = pd.DataFrame(score_list_max)
    score_frame_max.to_csv('../result/no_sur/score_list_max.txt', header=0)

    ##找自己最小
    name_score_frame = pd.DataFrame(name_score_list, columns=['first_wav_name', 'second_wav_name', 'score'])
    score_list_min = []
    for i in range(len(name_score_frame['first_wav_name'].unique())):
        name = name_score_frame['first_wav_name'].unique()[i]
        diff_name_frame = name_score_frame[
            (name_score_frame['first_wav_name'] == name) & (name_score_frame['second_wav_name'] == name)]
        max_diff_score = min(diff_name_frame['score'])
        score_list_min.append([name, max_diff_score])

    score_frame_min = pd.DataFrame(score_list_min)
    score_frame_min.to_csv('../result/no_sur/score_list_min.txt', header=0)

    #找自己90%的最小分数，假设有100个分数，排序，取排90的分数，只要大于这个分数，就认为是这个人。
    name_score_frame = pd.DataFrame(name_score_list, columns=['first_wav_name', 'second_wav_name', 'score'])
    score_list90 = []
    for i in range(len(name_score_frame['first_wav_name'].unique())):
        name = name_score_frame['first_wav_name'].unique()[i]
        diff_name_frame = name_score_frame[
            (name_score_frame['first_wav_name'] == name) & (name_score_frame['second_wav_name'] == name)]
        rank_90 = int(0.9*(len(diff_name_frame)))
        diff_name_frame = diff_name_frame.reset_index(drop=True)
        diff_name_frame = diff_name_frame.sort_values(ascending=False, by='score')#sort
        print(diff_name_frame)
        max_diff_score = diff_name_frame['score'][rank_90]
        score_list90.append([name, max_diff_score])

    score_frame90 = pd.DataFrame(score_list90)
    score_frame90.to_csv('../result/no_sur/score_list90.txt', header = 0)

    # 找第三四分3位数，，只要大于这个分数，就认为是这个人。
    name_score_frame = pd.DataFrame(name_score_list, columns=['first_wav_name', 'second_wav_name', 'score'])
    score_list75 = []
    for i in range(len(name_score_frame['first_wav_name'].unique())):
        name = name_score_frame['first_wav_name'].unique()[i]
        diff_name_frame = name_score_frame[
            (name_score_frame['first_wav_name'] == name) & (name_score_frame['second_wav_name'] == name)]
        rank_75 = int(0.75 * (len(diff_name_frame)))
        diff_name_frame = diff_name_frame.reset_index(drop=True)
        diff_name_frame = diff_name_frame.sort_values(ascending=False, by='score')  # sort
        print(diff_name_frame)
        max_diff_score = diff_name_frame['score'][rank_75]
        score_list75.append([name, max_diff_score])

    score_frame75 = pd.DataFrame(score_list75)
    score_frame75.to_csv('../result/no_sur/score_list75.txt', header=0)




if __name__ == "__main__":
    # Time the whole run and report the elapsed whole seconds.
    t_begin = datetime.datetime.now()
    main()
    t_end = datetime.datetime.now()
    print((t_end - t_begin).seconds, 's')